# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import logging
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from gluonts.core.component import validated
from gluonts.model import Input, InputSpec
from gluonts.torch.distributions import Output
from gluonts.torch.distributions.quantile_output import QuantileOutput
from gluonts.torch.scaler import StdScaler
from gluonts.torch.util import weighted_average

from .layers import (
    FeatureEmbedder,
    FeatureProjector,
    GatedResidualNetwork,
    TemporalFusionDecoder,
    TemporalFusionEncoder,
    VariableSelectionNetwork,
)

logger = logging.getLogger(__name__)


class TemporalFusionTransformerModel(nn.Module):
    """
    Temporal Fusion Transformer neural network.

    Partially based on the implementation in github.com/kashif/pytorch-transformer-ts.

    Inputs feat_static_real, feat_static_cat and feat_dynamic_real are mandatory.
    Inputs feat_dynamic_cat, past_feat_dynamic_real and past_feat_dynamic_cat are optional.
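
    A minimal usage sketch (illustrative only; it assumes
    ``InputSpec.zeros()`` builds zero-valued dummy inputs matching
    ``describe_inputs``):

    >>> model = TemporalFusionTransformerModel(
    ...     context_length=24, prediction_length=12
    ... )
    >>> dummy_inputs = model.describe_inputs(batch_size=2).zeros()
    >>> distr_args, loc, scale = model(**dummy_inputs)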
    """

    @validated()
    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        d_feat_static_real: Optional[List[int]] = None,  # Defaults to [1]
        c_feat_static_cat: Optional[List[int]] = None,  # Defaults to [1]
        d_feat_dynamic_real: Optional[List[int]] = None,  # Defaults to [1]
        c_feat_dynamic_cat: Optional[List[int]] = None,  # Defaults to []
        d_past_feat_dynamic_real: Optional[List[int]] = None,  # Defaults to []
        c_past_feat_dynamic_cat: Optional[List[int]] = None,  # Defaults to []
        num_heads: int = 4,
        d_hidden: int = 32,
        d_var: int = 32,
        dropout_rate: float = 0.1,
        distr_output: Optional[Output] = None,
    ):
        super().__init__()

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.num_heads = num_heads
        self.d_hidden = d_hidden
        self.d_var = d_var
        self.dropout_rate = dropout_rate

        if distr_output is None:
            distr_output = QuantileOutput(
                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
            )
        self.distr_output = distr_output
        self.args_proj = self.distr_output.get_args_proj(
            in_features=self.d_hidden
        )

        self.d_feat_static_real = d_feat_static_real or [1]
        self.d_feat_dynamic_real = d_feat_dynamic_real or [1]
        self.d_past_feat_dynamic_real = d_past_feat_dynamic_real or []
        self.c_feat_static_cat = c_feat_static_cat or [1]
        self.c_feat_dynamic_cat = c_feat_dynamic_cat or []
        self.c_past_feat_dynamic_cat = c_past_feat_dynamic_cat or []

        self.num_feat_static = len(self.d_feat_static_real) + len(
            self.c_feat_static_cat
        )
        self.num_feat_dynamic = len(self.d_feat_dynamic_real) + len(
            self.c_feat_dynamic_cat
        )
        self.num_past_feat_dynamic = len(self.d_past_feat_dynamic_real) + len(
            self.c_past_feat_dynamic_cat
        )

        # Standardize the target per series along the time axis; keepdim=True
        # keeps loc and scale broadcastable for rescaling the outputs later.
        self.scaler = StdScaler(dim=1, keepdim=True)

        # Project the scaled univariate target into the d_var embedding space
        self.target_proj = nn.Linear(in_features=1, out_features=self.d_var)
        # Past-only dynamic features
        if self.d_past_feat_dynamic_real:
            self.past_feat_dynamic_proj: Optional[FeatureProjector] = (
                FeatureProjector(
                    feature_dims=self.d_past_feat_dynamic_real,
                    embedding_dims=[self.d_var]
                    * len(self.d_past_feat_dynamic_real),
                )
            )
        else:
            self.past_feat_dynamic_proj = None

        if self.c_past_feat_dynamic_cat:
            self.past_feat_dynamic_embed: Optional[FeatureEmbedder] = (
                FeatureEmbedder(
                    cardinalities=self.c_past_feat_dynamic_cat,
                    embedding_dims=[self.d_var]
                    * len(self.c_past_feat_dynamic_cat),
                )
            )
        else:
            self.past_feat_dynamic_embed = None

        # Known dynamic features
        if self.d_feat_dynamic_real:
            self.feat_dynamic_proj: Optional[FeatureProjector] = (
                FeatureProjector(
                    feature_dims=self.d_feat_dynamic_real,
                    embedding_dims=[self.d_var]
                    * len(self.d_feat_dynamic_real),
                )
            )
        else:
            self.feat_dynamic_proj = None

        if self.c_feat_dynamic_cat:
            self.feat_dynamic_embed: Optional[FeatureEmbedder] = (
                FeatureEmbedder(
                    cardinalities=self.c_feat_dynamic_cat,
                    embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat),
                )
            )
        else:
            self.feat_dynamic_embed = None

        # Static features
        if self.d_feat_static_real:
            self.feat_static_proj: Optional[FeatureProjector] = (
                FeatureProjector(
                    feature_dims=self.d_feat_static_real,
                    embedding_dims=[self.d_var] * len(self.d_feat_static_real),
                )
            )
        else:
            self.feat_static_proj = None

        if self.c_feat_static_cat:
            self.feat_static_embed: Optional[FeatureEmbedder] = (
                FeatureEmbedder(
                    cardinalities=self.c_feat_static_cat,
                    embedding_dims=[self.d_var] * len(self.c_feat_static_cat),
                )
            )
        else:
            self.feat_static_embed = None

        # Variable selection networks compute per-variable weights over the
        # embedded covariates (static, past, and future, respectively)
        self.static_selector = VariableSelectionNetwork(
            d_hidden=self.d_var,
            num_vars=self.num_feat_static,
            dropout=self.dropout_rate,
        )
        self.ctx_selector = VariableSelectionNetwork(
            d_hidden=self.d_var,
            # +1 accounts for the projected target itself
            num_vars=self.num_past_feat_dynamic + self.num_feat_dynamic + 1,
            add_static=True,
            dropout=self.dropout_rate,
        )
        self.tgt_selector = VariableSelectionNetwork(
            d_hidden=self.d_var,
            num_vars=self.num_feat_dynamic,
            add_static=True,
            dropout=self.dropout_rate,
        )
        # Static covariates are turned into four context vectors: one
        # conditioning variable selection, one enriching temporal features,
        # and two initializing the encoder's recurrent states
        self.selection = GatedResidualNetwork(
            d_hidden=self.d_var,
            dropout=self.dropout_rate,
        )
        self.enrichment = GatedResidualNetwork(
            d_hidden=self.d_var,
            dropout=self.dropout_rate,
        )
        self.state_h = GatedResidualNetwork(
            d_hidden=self.d_var,
            d_output=self.d_hidden,
            dropout=self.dropout_rate,
        )
        self.state_c = GatedResidualNetwork(
            d_hidden=self.d_var,
            d_output=self.d_hidden,
            dropout=self.dropout_rate,
        )
        # Sequence-to-sequence encoder over past and future inputs, followed
        # by the attention-based decoder producing one hidden vector per
        # prediction step
        self.temporal_encoder = TemporalFusionEncoder(
            d_input=self.d_var,
            d_hidden=self.d_hidden,
        )
        self.temporal_decoder = TemporalFusionDecoder(
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            d_hidden=self.d_hidden,
            d_var=self.d_var,
            num_heads=self.num_heads,
            dropout=self.dropout_rate,
        )

    def describe_inputs(self, batch_size: int = 1) -> InputSpec:
        return InputSpec(
            {
                "past_target": Input(
                    shape=(batch_size, self.context_length), dtype=torch.float
                ),
                "past_observed_values": Input(
                    shape=(batch_size, self.context_length), dtype=torch.float
                ),
                "feat_static_real": Input(
                    shape=(batch_size, sum(self.d_feat_static_real)),
                    dtype=torch.float,
                ),
                "feat_static_cat": Input(
                    shape=(batch_size, len(self.c_feat_static_cat)),
                    dtype=torch.long,
                ),
                "feat_dynamic_real": Input(
                    shape=(
                        batch_size,
                        self.context_length + self.prediction_length,
                        sum(self.d_feat_dynamic_real),
                    ),
                    dtype=torch.float,
                ),
                "feat_dynamic_cat": Input(
                    shape=(
                        batch_size,
                        self.context_length + self.prediction_length,
                        len(self.c_feat_dynamic_cat),
                    ),
                    dtype=torch.long,
                ),
                "past_feat_dynamic_real": Input(
                    shape=(
                        batch_size,
                        self.context_length,
                        sum(self.d_past_feat_dynamic_real),
                    ),
                    dtype=torch.float,
                ),
                "past_feat_dynamic_cat": Input(
                    shape=(
                        batch_size,
                        self.context_length,
                        len(self.c_past_feat_dynamic_cat),
                    ),
                    dtype=torch.long,
                ),
            },
            torch.zeros,
        )

    def input_types(self) -> Dict[str, torch.dtype]:
        return {
            "past_target": torch.float,
            "past_observed_values": torch.float,
            "feat_static_real": torch.float,
            "feat_static_cat": torch.long,
            "feat_dynamic_real": torch.float,
            "feat_dynamic_cat": torch.long,
            "past_feat_dynamic_real": torch.float,
            "past_feat_dynamic_cat": torch.long,
        }

    def _preprocess(
        self,
        past_target: torch.Tensor,  # [N, T]
        past_observed_values: torch.Tensor,  # [N, T]
        feat_static_real: Optional[torch.Tensor],  # [N, D_sr]
        feat_static_cat: Optional[torch.Tensor],  # [N, D_sc]
        feat_dynamic_real: Optional[torch.Tensor],  # [N, T + H, D_dr]
        feat_dynamic_cat: Optional[torch.Tensor],  # [N, T + H, D_dc]
        past_feat_dynamic_real: Optional[torch.Tensor],  # [N, T, D_pr]
        past_feat_dynamic_cat: Optional[torch.Tensor],  # [N, T, D_pc]
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[torch.Tensor],
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Scale the target and embed/project all covariates into d_var-sized
        vectors, grouped into past, future, and static covariate lists.
        """
        past_target, loc, scale = self.scaler(
            data=past_target, weights=past_observed_values
        )

        past_covariates = [self.target_proj(past_target.unsqueeze(-1))]
        future_covariates = []
        static_covariates = []
        if self.past_feat_dynamic_proj is not None:
            projs = self.past_feat_dynamic_proj(past_feat_dynamic_real)
            past_covariates.extend(projs)
        if self.past_feat_dynamic_embed is not None:
            embs = self.past_feat_dynamic_embed(past_feat_dynamic_cat)
            past_covariates.extend(embs)

        if self.feat_dynamic_proj is not None:
            projs = self.feat_dynamic_proj(feat_dynamic_real)
            for proj in projs:
                ctx_proj = proj[..., : self.context_length, :]
                tgt_proj = proj[..., self.context_length :, :]
                past_covariates.append(ctx_proj)
                future_covariates.append(tgt_proj)
        if self.feat_dynamic_embed is not None:
            embs = self.feat_dynamic_embed(feat_dynamic_cat)
            for emb in embs:
                ctx_emb = emb[..., : self.context_length, :]
                tgt_emb = emb[..., self.context_length :, :]
                past_covariates.append(ctx_emb)
                future_covariates.append(tgt_emb)

        if self.feat_static_proj is not None:
            projs = self.feat_static_proj(feat_static_real)
            static_covariates.extend(projs)
        if self.feat_static_embed is not None:
            embs = self.feat_static_embed(feat_static_cat)
            static_covariates.extend(embs)

        return (
            past_covariates,
            future_covariates,
            static_covariates,
            loc,
            scale,
        )

    def forward(
        self,
        past_target: torch.Tensor,  # [N, T]
        past_observed_values: torch.Tensor,  # [N, T]
        feat_static_real: Optional[torch.Tensor],  # [N, D_sr]
        feat_static_cat: Optional[torch.Tensor],  # [N, D_sc]
        feat_dynamic_real: Optional[torch.Tensor],  # [N, T + H, D_dr]
        feat_dynamic_cat: Optional[torch.Tensor] = None,  # [N, T + H, D_dc]
        past_feat_dynamic_real: Optional[torch.Tensor] = None,  # [N, T, D_pr]
        past_feat_dynamic_cat: Optional[torch.Tensor] = None,  # [N, T, D_pc]
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
        (
            past_covariates,  # [[N, T, d_var], ...]
            future_covariates,  # [[N, H, d_var], ...]
            static_covariates,  # [[N, d_var], ...]
            loc,  # [N, 1]
            scale,  # [N, 1]
        ) = self._preprocess(
            past_target=past_target,
            past_observed_values=past_observed_values,
            feat_static_real=feat_static_real,
            feat_static_cat=feat_static_cat,
            feat_dynamic_real=feat_dynamic_real,
            feat_dynamic_cat=feat_dynamic_cat,
            past_feat_dynamic_real=past_feat_dynamic_real,
            past_feat_dynamic_cat=past_feat_dynamic_cat,
        )

        static_var, _ = self.static_selector(static_covariates)  # [N, d_var]
        c_selection = self.selection(static_var).unsqueeze(1)  # [N, 1, d_var]
        c_enrichment = self.enrichment(static_var).unsqueeze(1)
        c_h = self.state_h(static_var)  # [N, self.d_hidden]
        c_c = self.state_c(static_var)  # [N, self.d_hidden]
        # Initial hidden and cell states for the encoder's recurrent layer
        states = [c_h.unsqueeze(0), c_c.unsqueeze(0)]

        ctx_input, _ = self.ctx_selector(
            past_covariates, c_selection
        )  # [N, T, d_var]
        tgt_input, _ = self.tgt_selector(
            future_covariates, c_selection
        )  # [N, H, d_var]

        encoding = self.temporal_encoder(
            ctx_input, tgt_input, states
        )  # [N, T + H, d_hidden]
        decoding = self.temporal_decoder(
            encoding, c_enrichment, past_observed_values
        )  # [N, H, d_hidden]
        distr_args = self.args_proj(decoding)
        return distr_args, loc, scale

    def loss(
        self,
        past_target: torch.Tensor,  # [N, T]
        past_observed_values: torch.Tensor,  # [N, T]
        future_target: torch.Tensor,  # [N, H]
        future_observed_values: torch.Tensor,  # [N, H]
        feat_static_real: torch.Tensor,  # [N, D_sr]
        feat_static_cat: torch.Tensor,  # [N, D_sc]
        feat_dynamic_real: torch.Tensor,  # [N, T + H, D_dr]
        feat_dynamic_cat: Optional[torch.Tensor] = None,  # [N, T + H, D_dc]
        past_feat_dynamic_real: Optional[torch.Tensor] = None,  # [N, T, D_pr]
        past_feat_dynamic_cat: Optional[torch.Tensor] = None,  # [N, T, D_pc]
    ) -> torch.Tensor:
        distr_args, loc, scale = self(
            past_target=past_target,
            past_observed_values=past_observed_values,
            feat_static_real=feat_static_real,
            feat_static_cat=feat_static_cat,
            feat_dynamic_real=feat_dynamic_real,
            feat_dynamic_cat=feat_dynamic_cat,
            past_feat_dynamic_real=past_feat_dynamic_real,
            past_feat_dynamic_cat=past_feat_dynamic_cat,
        )  # distr_args: tuple of tensors, each [N, H, ...]
        loss = self.distr_output.loss(
            target=future_target,
            distr_args=distr_args,
            loc=loc,
            scale=scale,
        )
        # Average over the prediction horizon, masking missing target values
        return weighted_average(loss, weights=future_observed_values, dim=-1)
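
# A minimal smoke-test sketch (illustrative only, not part of the library):
# instantiate the model with the default feature configuration and run a
# forward pass on zero-valued dummy tensors of the shapes documented in
# ``describe_inputs``; the optional inputs are left at their None defaults.
if __name__ == "__main__":
    model = TemporalFusionTransformerModel(
        context_length=24, prediction_length=12
    )
    distr_args, loc, scale = model(
        past_target=torch.zeros(2, 24),
        past_observed_values=torch.ones(2, 24),
        feat_static_real=torch.zeros(2, 1),
        feat_static_cat=torch.zeros(2, 1, dtype=torch.long),
        feat_dynamic_real=torch.zeros(2, 36, 1),
    )
    print([a.shape for a in distr_args], loc.shape, scale.shape)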
